{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "(basic-mat-mult)=\n",
        "# 简单的矩阵乘法\n",
        "\n",
        "**原作者**: [Thierry Moreau](https://homes.cs.washington.edu/~moreau/)\n",
        "\n",
        "在本教程构建在 {ref}`vta-get-started` 教程的基础上，并介绍在 VTA 上使用 TVM 工作流实现矩阵乘法所需的额外概念。\n",
        "\n",
        "## RPC 设置\n",
        "\n",
        "从编程 Pynq 的 FPGA 和构建它的 RPC 运行时开始，就像在 VTA 介绍性教程中做的那样。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "tags": [
          "remove-cell"
        ]
      },
      "outputs": [],
      "source": [
        "import set_env"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import tvm\n",
        "from tvm import te\n",
        "import vta\n",
        "import numpy as np\n",
        "from tvm import rpc\n",
        "from vta.testing import simulator\n",
        "\n",
        "# 从 3rdparty/vta-hw/config/vta_config.json 文件载入 VTA 参数\n",
        "env = vta.get_env()\n",
        "\n",
        "# 从操作系统环境中读取 Pynq RPC 主机 IP 地址和端口号\n",
        "host = os.environ.get(\"VTA_RPC_HOST\", \"192.168.2.99\")\n",
        "port = int(os.environ.get(\"VTA_RPC_PORT\", \"9091\"))\n",
        "\n",
        "# 在 Pynq 上配置 bitstream 和运行时系统，\n",
        "# 以匹配 vta_config.json 文件指定的 VTA 配置。\n",
        "if env.TARGET in [\"pynq\", \"de10nano\"]:\n",
        "    # 确保使用 RPC=1 编译 TVM\n",
        "    assert tvm.runtime.enabled(\"rpc\")\n",
        "    remote = rpc.connect(host, port)\n",
        "\n",
        "    # 重新配置 JIT runtime\n",
        "    vta.reconfig_runtime(remote)\n",
        "\n",
        "    # 用预编译的 VTA bitstream 程序编程 FPGA。\n",
        "    # 通过将路径传递给 bitstream 文件而不是 None，\n",
        "    # 您可以使用自定义 bitstream 编程 FPGA。\n",
        "    vta.program_fpga(remote, bitstream=None)\n",
        "\n",
        "# 在仿真模式下，在本地托管 RPC 服务器。\n",
        "elif env.TARGET in [\"sim\", \"tsim\"]:\n",
        "    remote = rpc.LocalSession()\n",
        "    if env.TARGET in [\"intelfocl\"]:\n",
        "        # program intelfocl aocx\n",
        "        vta.program_fpga(remote, bitstream=\"vta.bitstream\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 计算声明\n",
        "\n",
        "在这个例子中，描述了简单的矩阵乘法加法，它需要多个计算阶段，如下面的数据流图所示。\n",
        "\n",
        "- 首先描述存在于 main memory 中的输入张量 `A` 和 `B`。\n",
        "- 其次，需要声明中间张量 `A_buf` 和 `B_buf`，它们将存在于 VTA 的 on-chip buffers 中。有了这个额外的计算阶段，就可以显式地分阶段 cached 读和写。\n",
        "- 接着，描述了 `A_buf` 和 `B_buf` 上的矩阵乘法运算，以产生 product matrix `C_buf`。\n",
        "- 最后的运算是强制转换和复制回 DRAM，到结果张量 `C`。\n",
        "\n",
        "\n",
        "```{image} images/gemm_dataflow.png\n",
        ":align: center\n",
        "```"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 数据布局\n",
        "\n",
        "以平铺数据格式描述占位符张量 `A`和 `B`，以匹配 VTA 张量核心施加的数据布局要求。\n",
        "\n",
        "::::{admonition} 数据平铺（Tiling）\n",
        ":class: alert alert-info\n",
        "\n",
        "以加速器为 target 时，复杂性的来源是确保数据布局与加速器设计施加的布局相匹配。VTA 是围绕 *tensor core* 设计的，它在激活矩阵和权值矩阵之间执行周期的矩阵-矩阵运算，将结果矩阵添加到累加器矩阵，如下图所示。\n",
        "\n",
        "```{image} images/tensor_core.png\n",
        ":align: center\n",
        ":width: 480px\n",
        "```\n",
        "\n",
        "矩阵-矩阵乘法的维度在 `vta_config.json` 配置文件中指定。激活矩阵为 `(BATCH, BLOCK_IN)` 形状，权重矩阵为 `(BLOCK_OUT, BLOCK_IN)` 形状，由此推断，得到的输出矩阵为 `(BATCH, BLOCK_OUT)` 形状。因此，VTA 处理的输入和输出张量需要根据上述尺寸平铺。\n",
        "\n",
        "```{note}\n",
        "数学公式表示：\n",
        "\n",
        "$$\n",
        "\\begin{cases}\n",
        "X = \\begin{pmatrix}\n",
        "   x_1 \\\\\n",
        "   \\vdots \\\\\n",
        "   x_{BATCH}\n",
        "\\end{pmatrix}\\\\\n",
        "W = \\begin{pmatrix}\n",
        "   w_1 \\\\\n",
        "   \\vdots \\\\\n",
        "   w_{BLOCK\\_OUT}\n",
        "\\end{pmatrix}\n",
        "\\end{cases}\n",
        "$$\n",
        "\n",
        "其中 $x_i, w_j \\in \\mathbb{R}^{BLOCK\\_IN}$。\n",
        "\n",
        "故而\n",
        "\n",
        "$$\n",
        "O = X W^T = \\begin{pmatrix}\n",
        "\\braket{x_i, w_j}\n",
        "\\end{pmatrix}\n",
        "$$\n",
        "```\n",
        "\n",
        "下图显示了数据平铺对最初形状为 (4,8) 的矩阵的影响。平铺 (2,2) tile 保证了每个平铺内的数据是连续的。得到的平铺张量的形状是 (2, 4, 2, 2)。\n",
        "\n",
        "```{image} images/data_tiling.png\n",
        ":align: center\n",
        ":width: 480px\n",
        "```\n",
        "::::\n",
        "\n",
        "首先定义变量 `m`，`n`，`o` 来表示矩阵乘法的形状。这些变量分别是 `BLOCK_OUT`、`BLOCK_IN` 和 `BATCH` 张量维度上的乘法因子。默认情况下，配置文件将 `BATCH`、`BLOCK_IN` 和 `BLOCK_OUT` 分别设置为 1、16 和 16（将 `BATCH` 设置为 1 意味着计算构建块是向量-矩阵乘法）。\n",
        "\n",
        "```{admonition} 数据类型\n",
        ":class: alert alert-info\n",
        "\n",
        "重要的是，不仅要匹配 VTA 张量核心的内部 tile 维度，而且要匹配 VTA 期望的特定数据类型。VTA 目前只支持定点数据类型（fixed point data types），整数宽度在 `vta_config.json` 中指定。`INP_WIDTH` 和 `WGT_WIDTH` 分别用于激活和权重数据类型。此外，累加器数据类型整型宽度由 `ACC_WIDTH` 指定。\n",
        "```\n",
        "\n",
        "默认情况下，配置文件将 `INP_WIDTH` 和 `WGT_WIDTH` 设置为 8。累加器宽度 `ACC_WIDTH` 被设置为 32，以避免累加时溢出。结果是 `env.inp_dtype` 和 `env.wgt_dtype` 都是窄化的 8 位整型，而 `env.acc_dtype` 是标准的 32 位整型。"
      ]
    },
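    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the overflow argument concrete, here is a rough bound, assuming signed 8-bit operands and a reduction over 256 terms (the $n \\times BLOCK\\_IN = 16 \\times 16$ input channels used in the next cell). The symbols $a_k$, $b_k$, and $S$ are introduced here for illustration only:\n",
        "\n",
        "$$\n",
        "|a_k \\cdot b_k| \\le 128 \\cdot 128 = 2^{14}, \\qquad |S| = \\Big|\\sum_{k=1}^{256} a_k b_k\\Big| \\le 2^{8} \\cdot 2^{14} = 2^{22} < 2^{31}\n",
        "$$\n",
        "\n",
        "Under these assumptions a 32-bit accumulator cannot overflow for this problem size, whereas a 16-bit accumulator (maximum $2^{15} - 1$) could already overflow after two worst-case products."
      ]
    },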
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# 输出通道因子 m - 总共 16 x 16 = 256 输出通道\n",
        "m = 16\n",
        "# 输入通道因子 n -总计 16x16=256 个输入通道\n",
        "n = 16\n",
        "# Batch 因子 o （使用单个 batch 推理）\n",
        "o = 1\n",
        "# tiled 据格式的 A 占位符张量\n",
        "A = te.placeholder((o, n, env.BATCH, env.BLOCK_IN), name=\"A\", dtype=env.inp_dtype)\n",
        "# tiled 据格式的 B 占位符张量\n",
        "B = te.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name=\"B\", dtype=env.wgt_dtype)\n",
        "# A copy buffer\n",
        "A_buf = te.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: A(*i), \"A_buf\")\n",
        "# B copy buffer\n",
        "B_buf = te.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: B(*i), \"B_buf\")"
      ]
    },
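    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of the shapes declared above, assuming the default configuration (`BATCH = 1`, `BLOCK_IN = 16`, `BLOCK_OUT = 16`) together with `m = 16`, `n = 16`, and `o = 1`, the tiled tensors correspond to the following logical matrices and element counts:\n",
        "\n",
        "$$\n",
        "\\begin{aligned}\n",
        "A &: (o, n, BATCH, BLOCK\\_IN) = (1, 16, 1, 16) &&\\Rightarrow\\ 1 \\times 256 \\text{ matrix},\\ 256 \\text{ elements} \\\\\n",
        "B &: (m, n, BLOCK\\_OUT, BLOCK\\_IN) = (16, 16, 16, 16) &&\\Rightarrow\\ 256 \\times 256 \\text{ matrix},\\ 65536 \\text{ elements} \\\\\n",
        "C &: (o, m, BATCH, BLOCK\\_OUT) = (1, 16, 1, 16) &&\\Rightarrow\\ 1 \\times 256 \\text{ matrix},\\ 256 \\text{ elements}\n",
        "\\end{aligned}\n",
        "$$\n",
        "\n",
        "These element counts agree with the allocation sizes (256 and 65536) that appear when the default schedule is lowered later in this tutorial."
      ]
    },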
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 矩阵乘法\n",
        "\n",
        "描述矩阵乘法的结果张量 `C`，还有另一个 compute 运算。compute 函数采用张量的形式，以及描述张量每个位置的计算规则的 lambda 函数。\n",
        "\n",
        "为了实现矩阵乘法，lambda 函数需要包含输入通道维度轴上的 reduction 公式。要创建 reduction 公式，可以使用 `te.reduce_axis` 声明 reduction axis，它在 reduction 的范围内。`te.sum` 接受要 reduction 的表达式和 reduction axes，以计算声明范围内所有 k 的值的和。\n",
        "\n",
        "注意 reduction 需要在 32 位 `env.acc_dtype` 累加器数据类型上执行。\n",
        "\n",
        "在这个阶段没有计算发生，因为只是声明应该如何进行计算。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Outer input feature reduction axis\n",
        "ko = te.reduce_axis((0, n), name=\"ko\")\n",
        "# Inner input feature reduction axis\n",
        "ki = te.reduce_axis((0, env.BLOCK_IN), name=\"ki\")\n",
        "# Describe the in-VTA matrix multiplication\n",
        "C_buf = te.compute(\n",
        "    (o, m, env.BATCH, env.BLOCK_OUT),\n",
        "    lambda bo, co, bi, ci: te.sum(\n",
        "        A_buf[bo, ko, bi, ki].astype(env.acc_dtype) * B_buf[co, ko, ci, ki].astype(env.acc_dtype),\n",
        "        axis=[ko, ki],\n",
        "    ),\n",
        "    name=\"C_buf\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Casting 结果\n",
        "\n",
        "计算完成后，需要将 VTA 计算的结果发送回 main memory。\n",
        "\n",
        "```{admonition} 内存存储的限制\n",
        ":class: alert alert-info\n",
        "\n",
        "VTA 的特点之一是，它只支持 DRAM 存储在窄化的 `env.inp_dtype` 数据类型格式。这使能够减少内存传输的数据占用，但也使能够将宽的累加器数据类型量化为与输入激活数据类型匹配的数据格式。这意味着在神经网络推理的背景下，激活某一层后的输出可以直接被下一层 consumed。\n",
        "```\n",
        "\n",
        "对窄化输入激活数据格式执行最后一次 typecast 运算。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Cast to output type, and send to main memory\n",
        "C = te.compute(\n",
        "    (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: C_buf(*i).astype(env.inp_dtype), name=\"C\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "本教程的计算声明部分到此结束。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 调度计算\n",
        "\n",
        "虽然上面几行描述了计算规则，但可以通过多种方式得到 `C`。TVM 要求用户提供名为 schedule 的计算实现。\n",
        "\n",
        "调度是对原始计算的一组变换，它在不影响正确性的情况下变换计算的实现。这个简单的 VTA 编程教程旨在演示基本的调度变换，将原始的调度映射到 VTA 硬件原语（primitive）。\n",
        "\n",
        "### 默认调度\n",
        "\n",
        "在构造了调度后，默认情况下，调度按以下方式计算 `C`："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">main</span>(A: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>), B: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>), C: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>)):\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>), <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        A_buf <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">256</span>], <span style=\"color: #BA2121\">&quot;int8&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        B_buf <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">65536</span>], <span style=\"color: #BA2121\">&quot;int8&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        C_buf <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">256</span>], <span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        A_buf_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">256</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>A_buf)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i1, i3 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
              "            cse_var_1: T<span style=\"color: #A2F; font-weight: bold\">.</span>int32 <span style=\"color: #A2F; font-weight: bold\">=</span> i1 <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i3\n",
              "            A_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">256</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>A<span style=\"color: #A2F; font-weight: bold\">.</span>data)\n",
              "            A_buf_1[cse_var_1] <span style=\"color: #A2F; font-weight: bold\">=</span> A_1[cse_var_1]\n",
              "        B_buf_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">65536</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>B_buf)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1, i2, i3 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
              "            cse_var_2: T<span style=\"color: #A2F; font-weight: bold\">.</span>int32 <span style=\"color: #A2F; font-weight: bold\">=</span> i0 <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i1 <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i2 <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i3\n",
              "            B_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">65536</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>B<span style=\"color: #A2F; font-weight: bold\">.</span>data)\n",
              "            B_buf_1[cse_var_2] <span style=\"color: #A2F; font-weight: bold\">=</span> B_1[cse_var_2]\n",
              "        C_buf_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">256</span>,), <span style=\"color: #BA2121\">&quot;int32&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>C_buf)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> co, ci <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
              "            C_buf_1[co <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ci] <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #008000\">0</span>\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> ko, ki <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
              "                cse_var_3: T<span style=\"color: #A2F; font-weight: bold\">.</span>int32 <span style=\"color: #A2F; font-weight: bold\">=</span> co <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ci\n",
              "                C_buf_1[cse_var_3] <span style=\"color: #A2F; font-weight: bold\">=</span> C_buf_1[cse_var_3] <span style=\"color: #A2F; font-weight: bold\">+</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Cast(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, A_buf_1[ko <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ki]) <span style=\"color: #A2F; font-weight: bold\">*</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Cast(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, B_buf_1[co <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ko <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ci <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ki])\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i1, i3 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
              "            cse_var_4: T<span style=\"color: #A2F; font-weight: bold\">.</span>int32 <span style=\"color: #A2F; font-weight: bold\">=</span> i1 <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i3\n",
              "            C_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">256</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>C<span style=\"color: #A2F; font-weight: bold\">.</span>data)\n",
              "            C_1[cse_var_4] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Cast(<span style=\"color: #BA2121\">&quot;int8&quot;</span>, C_buf_1[cse_var_4])\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Let's take a look at the generated schedule\n",
        "s = te.create_schedule(C.op)\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "虽然此调度有意义，但它不会编译到 VTA。为了获得正确的代码生成，需要应用调度原语和代码注解，将调度变换为可以直接 lower 至 VTA 硬件 intrinsic 的调度。这些包括：\n",
        "\n",
        "- DMA 复制运算，将全局作用域张量复制到局部作用域张量。\n",
        "- 用来做矩阵乘法的张量运算。\n",
        "\n",
        "### Buffer 作用域\n",
        "\n",
        "首先，设置 buffer 的作用域来告诉 TVM 这些 buffer 将存在于 VTA 的 on-chip SRAM cache 中。下面，告诉 TVM, `A_buf`，`B_buf`，`C_buf` 将分别存在于 VTA 的 on-chip 输入，权重和累加器（accumulator）内存中。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "```{admonition} VTA's On-Chip SRAMs\n",
        ":class: alert alert-info\n",
        "\n",
        "VTA 有三个不同的内存作用域，每个都对应于不同的片上 SRAM buffer。\n",
        "\n",
        "- `env.inp_scope`：输入 buffer，这是只读的 SRAM buffer，存储形状为 `(env.BATCH, env.BLOCK_IN)`，类型为 `env.inp_dtype` 的矩阵。输入 buffer 包含 $2^{LOG\\_INP\\_BUFF\\_SIZE}$ 个矩阵元素（在 `vta_config.json` 文件中指定）。\n",
        "- `env.wgt_scope`：权重 buffer，这是只读的 SRAM buffer，存储形状为 `(env.BLOCK_OUT, env.BLOCK_IN)`，类型为 `env.wgt_dtype` 的矩阵。权重 buffer 包含 $2^{LOG\\_WGT\\_BUFF\\_SIZE}$ 个矩阵元素。\n",
        "- `env.acc_scope`： Accumulator buffer，这是读/写 SRAM buffer，存储形状为 `(env.BATCH, env.BLOCK_OUT)`，类型为 `env.acc_dtype` 的累加矩阵。累加器 buffer 是 VTA 的通用寄存器文件：它既保存卷积和矩阵乘法的中间结果，也保存池化、batch normalization 和激活层的中间结果。累加器缓冲区包含 $2^{LOG\\_ACC\\_BUFF\\_SIZE}$ 个矩阵元素。\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "stage(C_buf, compute(C_buf, body=[T.reduce(T.comm_reducer(lambda x, y: x + y, [0]), source=[T.Cast(\"int32\", A_buf[bo, ko, bi, ki]) * T.Cast(\"int32\", B_buf[co, ko, ci, ki])], init=[], axis=[T.iter_var(ko, T.Range(0, 16), \"CommReduce\", \"\"), T.iter_var(ki, T.Range(0, 16), \"CommReduce\", \"\")], condition=T.bool(True), value_index=0)], axis=[T.iter_var(bo, T.Range(0, 1), \"DataPar\", \"\"), T.iter_var(co, T.Range(0, 16), \"DataPar\", \"\"), T.iter_var(bi, T.Range(0, 1), \"DataPar\", \"\"), T.iter_var(ci, T.Range(0, 16), \"DataPar\", \"\")], reduce_axis=[T.iter_var(ko, T.Range(0, 16), \"CommReduce\", \"\"), T.iter_var(ki, T.Range(0, 16), \"CommReduce\", \"\")], tag=, attrs={}))"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Set the intermediate tensor's scope to VTA's on-chip buffers\n",
        "s[A_buf].set_scope(env.inp_scope)\n",
        "s[B_buf].set_scope(env.wgt_scope)\n",
        "s[C_buf].set_scope(env.acc_scope)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### DMA 传输\n",
        "\n",
        "需要调度 DMA transfer 来将存在于 DRAM 中的数据移动到 VTA on-chip buffer。这可以使用 `compute_at` 调度原语实现，该原语将 buffer 的复制嵌套到执行矩阵乘法的计算循环中。\n",
        "\n",
        "插入 `dma_copy` pragmas 来指示编译器，复制运算将通过 DMA 批量执行，这在硬件加速器中很常见。最后，打印临时调度，观察将复制运算移动到矩阵乘法循环中的效果。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">main</span>(A: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>), B: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>), C: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>)):\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>), <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        C_buf <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">256</span>], <span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;local.acc_buffer&quot;</span>)\n",
              "        A_buf <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">16</span>], <span style=\"color: #BA2121\">&quot;int8&quot;</span>, <span style=\"color: #BA2121\">&quot;local.inp_buffer&quot;</span>)\n",
              "        B_buf <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">16</span>], <span style=\"color: #BA2121\">&quot;int8&quot;</span>, <span style=\"color: #BA2121\">&quot;local.wgt_buffer&quot;</span>)\n",
              "        C_buf_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">256</span>,), <span style=\"color: #BA2121\">&quot;int32&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>C_buf, scope<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;local.acc_buffer&quot;</span>, align<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">16</span>)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> co, ci <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
              "            C_buf_1[co <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ci] <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #008000\">0</span>\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> ko <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">16</span>):\n",
              "                i0 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>int32()\n",
              "                A_buf_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>A_buf, scope<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;local.inp_buffer&quot;</span>, align<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">16</span>)\n",
              "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(i0, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;DataPar&quot;</span>, <span style=\"color: #BA2121\">&quot;&quot;</span>), <span style=\"color: #BA2121\">&quot;pragma_dma_copy&quot;</span>, <span style=\"color: #008000\">1</span>):\n",
              "                    <span style=\"color: #008000; font-weight: bold\">for</span> i3 <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">16</span>):\n",
              "                        A_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">256</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>A<span style=\"color: #A2F; font-weight: bold\">.</span>data)\n",
              "                        A_buf_1[i3] <span style=\"color: #A2F; font-weight: bold\">=</span> A_1[ko <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i3]\n",
              "                i0_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>int32()\n",
              "                B_buf_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>B_buf, scope<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;local.wgt_buffer&quot;</span>, align<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">256</span>)\n",
              "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(i0_1, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;DataPar&quot;</span>, <span style=\"color: #BA2121\">&quot;&quot;</span>), <span style=\"color: #BA2121\">&quot;pragma_dma_copy&quot;</span>, <span style=\"color: #008000\">1</span>):\n",
              "                    <span style=\"color: #008000; font-weight: bold\">for</span> i3 <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">16</span>):\n",
              "                        B_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">65536</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>B<span style=\"color: #A2F; font-weight: bold\">.</span>data)\n",
              "                        B_buf_1[i3] <span style=\"color: #A2F; font-weight: bold\">=</span> B_1[co <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ko <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ci <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i3]\n",
              "                <span style=\"color: #008000; font-weight: bold\">for</span> ki <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">16</span>):\n",
              "                    cse_var_1: T<span style=\"color: #A2F; font-weight: bold\">.</span>int32 <span style=\"color: #A2F; font-weight: bold\">=</span> co <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> ci\n",
              "                    C_buf_1[cse_var_1] <span style=\"color: #A2F; font-weight: bold\">=</span> C_buf_1[cse_var_1] <span style=\"color: #A2F; font-weight: bold\">+</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Cast(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, A_buf_1[ki]) <span style=\"color: #A2F; font-weight: bold\">*</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Cast(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, B_buf_1[ki])\n",
              "        i0 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>int32()\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(i0, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;DataPar&quot;</span>, <span style=\"color: #BA2121\">&quot;&quot;</span>), <span style=\"color: #BA2121\">&quot;pragma_dma_copy&quot;</span>, <span style=\"color: #008000\">1</span>)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i1, i3 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
              "            cse_var_2: T<span style=\"color: #A2F; font-weight: bold\">.</span>int32 <span style=\"color: #A2F; font-weight: bold\">=</span> i1 <span style=\"color: #A2F; font-weight: bold\">*</span> <span style=\"color: #008000\">16</span> <span style=\"color: #A2F; font-weight: bold\">+</span> i3\n",
              "            C_1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">256</span>,), <span style=\"color: #BA2121\">&quot;int8&quot;</span>, data<span style=\"color: #A2F; font-weight: bold\">=</span>C<span style=\"color: #A2F; font-weight: bold\">.</span>data)\n",
              "            C_1[cse_var_2] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>Cast(<span style=\"color: #BA2121\">&quot;int8&quot;</span>, C_buf_1[cse_var_2])\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Move buffer copy into matrix multiply loop\n",
        "s[A_buf].compute_at(s[C_buf], ko)\n",
        "s[B_buf].compute_at(s[C_buf], ko)\n",
        "\n",
        "# Tag the buffer copies with the DMA pragma to insert a DMA transfer\n",
        "s[A_buf].pragma(s[A_buf].op.axis[0], env.dma_copy)\n",
        "s[B_buf].pragma(s[B_buf].op.axis[0], env.dma_copy)\n",
        "s[C].pragma(s[C].op.axis[0], env.dma_copy)\n",
        "\n",
        "# Let's take a look at the transformed schedule\n",
        "tvm.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
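        "### Tensorization\n",
        "\n",
        "The last step of the schedule transformation is to apply *tensorization* to the schedule. Tensorization is analogous to vectorization, but extends the concept to higher-dimensional compute units. Consequently, tensorization imposes data layout constraints when the input placeholders are declared. The tensors have already been arranged in a tiled format, so all that remains is to reorder the loops to accommodate tensorization.\n",
        "\n",
        "Here, the outermost reduction axis is moved all the way out. This means the computation first iterates over input channels, then over the batch dimension, and finally over output channels. Lastly, the `tensorize` scheduling primitive is applied along the outer axis of the innermost matrix multiplication tensor block. The finalized schedule is then printed; it is ready for code generation by the VTA runtime JIT compiler."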
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle\n",
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_sync\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">main</span>(A: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>), B: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>), C: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;int8&quot;</span>)):\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>), <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        vta <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>int32()\n",
              "        <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(vta, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;ThreadIndex&quot;</span>, <span style=\"color: #BA2121\">&quot;vta&quot;</span>), <span style=\"color: #BA2121\">&quot;coproc_scope&quot;</span>, <span style=\"color: #008000\">2</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(vta, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;ThreadIndex&quot;</span>, <span style=\"color: #BA2121\">&quot;vta&quot;</span>), <span style=\"color: #BA2121\">&quot;coproc_uop_scope&quot;</span>, <span style=\"color: #BA2121\">&quot;VTAPushGEMMOp&quot;</span>):\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>call_extern(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;VTAUopLoopBegin&quot;</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>)\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>uop_push(<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>)\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>call_extern(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;VTAUopLoopEnd&quot;</span>)\n",
              "            T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_push(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>)\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> ko <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">16</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(vta, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;ThreadIndex&quot;</span>, <span style=\"color: #BA2121\">&quot;vta&quot;</span>), <span style=\"color: #BA2121\">&quot;coproc_scope&quot;</span>, <span style=\"color: #008000\">1</span>):\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_pop(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>)\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>call_extern(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;VTALoadBuffer2D&quot;</span>, T<span style=\"color: #A2F; font-weight: bold\">.</span>tvm_thread_context(T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>command_handle()), A<span style=\"color: #A2F; font-weight: bold\">.</span>data, ko, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">2</span>)\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>call_extern(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;VTALoadBuffer2D&quot;</span>, T<span style=\"color: #A2F; font-weight: bold\">.</span>tvm_thread_context(T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>command_handle()), B<span style=\"color: #A2F; font-weight: bold\">.</span>data, ko, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">1</span>)\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_push(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>)\n",
              "            T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(vta, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;ThreadIndex&quot;</span>, <span style=\"color: #BA2121\">&quot;vta&quot;</span>), <span style=\"color: #BA2121\">&quot;coproc_scope&quot;</span>, <span style=\"color: #008000\">2</span>)\n",
              "            T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_pop(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(vta, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;ThreadIndex&quot;</span>, <span style=\"color: #BA2121\">&quot;vta&quot;</span>), <span style=\"color: #BA2121\">&quot;coproc_uop_scope&quot;</span>, <span style=\"color: #BA2121\">&quot;VTAPushGEMMOp&quot;</span>):\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>call_extern(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;VTAUopLoopBegin&quot;</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">1</span>)\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>uop_push(<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>)\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>call_extern(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;VTAUopLoopEnd&quot;</span>)\n",
              "            T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_push(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>)\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_push(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>)\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_pop(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>)\n",
              "        <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>attr(T<span style=\"color: #A2F; font-weight: bold\">.</span>iter_var(vta, <span style=\"color: #008000; font-weight: bold\">None</span>, <span style=\"color: #BA2121\">&quot;ThreadIndex&quot;</span>, <span style=\"color: #BA2121\">&quot;vta&quot;</span>), <span style=\"color: #BA2121\">&quot;coproc_scope&quot;</span>, <span style=\"color: #008000\">3</span>):\n",
              "            T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_dep_pop(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>)\n",
              "            T<span style=\"color: #A2F; font-weight: bold\">.</span>call_extern(<span style=\"color: #BA2121\">&quot;int32&quot;</span>, <span style=\"color: #BA2121\">&quot;VTAStoreBuffer2D&quot;</span>, T<span style=\"color: #A2F; font-weight: bold\">.</span>tvm_thread_context(T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>command_handle()), <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">4</span>, C<span style=\"color: #A2F; font-weight: bold\">.</span>data, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>)\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>tir<span style=\"color: #A2F; font-weight: bold\">.</span>vta<span style=\"color: #A2F; font-weight: bold\">.</span>coproc_sync()\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
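        "# Move the outer reduction axis ko outermost and keep ki innermost\n",
        "s[C_buf].reorder(\n",
        "    ko, s[C_buf].op.axis[0], s[C_buf].op.axis[1], s[C_buf].op.axis[2], s[C_buf].op.axis[3], ki\n",
        ")\n",
        "# Map the innermost tensor block onto the VTA GEMM intrinsic\n",
        "s[C_buf].tensorize(s[C_buf].op.axis[2], env.gemm)\n",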
        "\n",
        "# Let's take a look at the finalized schedule\n",
        "vta.lower(s, [A, B, C], simple_mode=True).show()"
      ]
    },
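    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop order chosen above can be mimicked in plain numpy. The cell below is only a minimal sketch, not what VTA executes: the tile sizes are illustrative stand-ins for `env.BATCH`, `env.BLOCK_IN` and `env.BLOCK_OUT`, and the helper `tiled_matmul_sketch` and its local names are hypothetical. The reduction tile index `ko` is the outermost loop, and each innermost step is one small matrix product, which corresponds to the block that `tensorize` hands to the GEMM intrinsic."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def tiled_matmul_sketch(o=1, n=4, m=3, batch=2, block_in=4, block_out=4, seed=0):\n",
        "    # Hypothetical helper: blocked matmul with the reduction tile loop outermost.\n",
        "    rng = np.random.default_rng(seed)\n",
        "    A_t = rng.integers(-128, 128, size=(o, n, batch, block_in)).astype(np.int8)\n",
        "    B_t = rng.integers(-128, 128, size=(m, n, block_out, block_in)).astype(np.int8)\n",
        "\n",
        "    C_acc = np.zeros((o, m, batch, block_out), dtype=np.int32)\n",
        "    for ko in range(n):          # reduction tiles (input channels), outermost\n",
        "        for bo in range(o):      # batch tiles\n",
        "            for co in range(m):  # output-channel tiles\n",
        "                # One tile-level product: (batch, block_in) x (block_in, block_out)\n",
        "                a = A_t[bo, ko].astype(np.int32)\n",
        "                b = B_t[co, ko].astype(np.int32)\n",
        "                C_acc[bo, co] += a @ b.T\n",
        "\n",
        "    # Reference: unpack to 2D, multiply, then repack into the tiled layout\n",
        "    A_2d = A_t.transpose(0, 2, 1, 3).reshape(o * batch, n * block_in).astype(np.int32)\n",
        "    B_2d = B_t.transpose(0, 2, 1, 3).reshape(m * block_out, n * block_in).astype(np.int32)\n",
        "    C_ref = (A_2d @ B_2d.T).reshape(o, batch, m, block_out).transpose(0, 2, 1, 3)\n",
        "    return C_acc, C_ref\n",
        "\n",
        "\n",
        "C_acc_sketch, C_ref_sketch = tiled_matmul_sketch()\n",
        "np.testing.assert_array_equal(C_acc_sketch, C_ref_sketch)"
      ]
    },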
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
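        "This concludes the scheduling portion of this tutorial.\n",
        "\n",
        "## TVM Compilation\n",
        "\n",
        "Once the schedule has been specified, it can be compiled into a TVM function."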
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "[20:02:54] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64\n",
            "2025-07-27 20:02:54.515 INFO load_module /tmp/tmp86vp45_g/gemm.o\n"
          ]
        }
      ],
      "source": [
        "# Build GEMM VTA kernel\n",
        "my_gemm = vta.build(s, [A, B, C],\n",
        "                    target=\"ext_dev\",\n",
        "                    name=\"my_gemm\")\n",
        "\n",
        "# Write the compiled module into an object file.\n",
        "temp = tvm.contrib.utils.tempdir()\n",
        "my_gemm.save(temp.relpath(\"gemm.o\"))\n",
        "\n",
        "# Send the executable over RPC\n",
        "remote.upload(temp.relpath(\"gemm.o\"))\n",
        "\n",
        "# Load the compiled module\n",
        "f = remote.load_module(\"gemm.o\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
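        "## Running the Function\n",
        "\n",
        "The compiled TVM function uses a concise C API and can be invoked from any language.\n",
        "\n",
        "TVM provides an array API in Python to aid quick testing and prototyping. The array API is based on the [DLPack](https://github.com/dmlc/dlpack) standard.\n",
        "\n",
        "- First, create the remote context (for remote execution on the Pynq).\n",
        "- Then {func}`tvm.nd.array` formats the data accordingly.\n",
        "- {func}`f` runs the actual computation.\n",
        "- `numpy()` copies the result array back in a format that can be interpreted."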
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Get the remote device context\n",
        "ctx = remote.ext_dev(0)\n",
        "\n",
        "# Initialize the A and B arrays randomly with ints in the range [-128, 128)\n",
        "A_orig = np.random.randint(-128, 128, size=(o * env.BATCH, n * env.BLOCK_IN)).astype(A.dtype)\n",
        "B_orig = np.random.randint(-128, 128, size=(m * env.BLOCK_OUT, n * env.BLOCK_IN)).astype(B.dtype)\n",
        "\n",
        "# Apply packing to the A and B arrays from a 2D to a 4D packed layout\n",
        "A_packed = A_orig.reshape(o, env.BATCH, n, env.BLOCK_IN).transpose((0, 2, 1, 3))\n",
        "B_packed = B_orig.reshape(m, env.BLOCK_OUT, n, env.BLOCK_IN).transpose((0, 2, 1, 3))\n",
        "\n",
        "# Format the input/output arrays with tvm.nd.array to the DLPack standard\n",
        "A_nd = tvm.nd.array(A_packed, ctx)\n",
        "B_nd = tvm.nd.array(B_packed, ctx)\n",
        "C_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx)\n",
        "\n",
        "# Clear stats\n",
        "if env.TARGET in [\"sim\", \"tsim\"]:\n",
        "    simulator.clear_stats()\n",
        "\n",
        "# Invoke the module to perform the computation\n",
        "f(A_nd, B_nd, C_nd)"
      ]
    },
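    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The 2D-to-4D packing used above is a `reshape` followed by a `transpose`, and it can be undone by applying the same transpose and reshaping back. The cell below is a small sketch of that round trip on a toy matrix; the helpers `pack_2d_to_4d` and `unpack_4d_to_2d` and the tile sizes are hypothetical and are not part of the VTA or TVM API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def pack_2d_to_4d(x, inner_rows, inner_cols):\n",
        "    # (R, C) -> (R // inner_rows, C // inner_cols, inner_rows, inner_cols)\n",
        "    rows, cols = x.shape\n",
        "    tiled = x.reshape(rows // inner_rows, inner_rows, cols // inner_cols, inner_cols)\n",
        "    return tiled.transpose((0, 2, 1, 3))\n",
        "\n",
        "\n",
        "def unpack_4d_to_2d(x):\n",
        "    # Inverse of pack_2d_to_4d: swap the middle axes back, then flatten the tiles\n",
        "    outer_r, outer_c, inner_r, inner_c = x.shape\n",
        "    return x.transpose((0, 2, 1, 3)).reshape(outer_r * inner_r, outer_c * inner_c)\n",
        "\n",
        "\n",
        "demo = np.arange(4 * 12).reshape(4, 12)\n",
        "demo_packed = pack_2d_to_4d(demo, inner_rows=2, inner_cols=4)\n",
        "assert demo_packed.shape == (2, 3, 2, 4)\n",
        "\n",
        "# Element (row, col) lands at [row // 2, col // 4, row % 2, col % 4]\n",
        "row, col = 3, 9\n",
        "assert demo_packed[row // 2, col // 4, row % 2, col % 4] == demo[row, col]\n",
        "\n",
        "# Unpacking restores the original 2D matrix\n",
        "np.testing.assert_array_equal(unpack_4d_to_2d(demo_packed), demo)"
      ]
    },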
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Verifying Correctness\n",
        "\n",
        "Compute the reference result with numpy and assert that the output of the matrix multiplication is indeed correct:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Execution statistics:\n",
            "\tinp_load_nbytes :              256\n",
            "\twgt_load_nbytes :            65536\n",
            "\tacc_load_nbytes :                0\n",
            "\tuop_load_nbytes :                8\n",
            "\tout_store_nbytes:              256\n",
            "\tgemm_counter    :              256\n",
            "\talu_counter     :                0\n",
            "Successful matrix multiply test!\n"
          ]
        }
      ],
      "source": [
        "# Compute reference result with numpy\n",
        "C_ref = np.dot(A_orig.astype(env.acc_dtype),\n",
        "               B_orig.T.astype(env.acc_dtype)).astype(C.dtype)\n",
        "C_ref = C_ref.reshape(o,\n",
        "                      env.BATCH,\n",
        "                      m,\n",
        "                      env.BLOCK_OUT).transpose((0, 2, 1, 3))\n",
        "np.testing.assert_equal(C_ref, C_nd.numpy())\n",
        "\n",
        "# Print stats\n",
        "if env.TARGET in [\"sim\", \"tsim\"]:\n",
        "    sim_stats = simulator.stats()\n",
        "    print(\"Execution statistics:\")\n",
        "    for k, v in sim_stats.items():\n",
        "        print(f\"\\t{k:<16}: {v:>16}\")\n",
        "\n",
        "print(\"Successful matrix multiply test!\")"
      ]
    },
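    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When the simulator prints execution statistics, the load and store byte counts can be cross-checked against the tensor shapes. The cell below is a rough sketch of that arithmetic. It assumes the shapes shown in the lowered schedule (`A`: (1, 16, 1, 16), `B`: (16, 16, 16, 16), `C`: (1, 16, 1, 16)), 1-byte `int8` elements, and that every tile is transferred exactly once; the mapping of `A`, `B` and `C` to the input, weight and output counters is likewise an assumption."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Shapes copied from the lowered schedule; int8 elements occupy 1 byte each\n",
        "shape_A = (1, 16, 1, 16)\n",
        "shape_B = (16, 16, 16, 16)\n",
        "shape_C = (1, 16, 1, 16)\n",
        "itemsize = np.dtype('int8').itemsize\n",
        "\n",
        "expected_bytes = {\n",
        "    'inp_load_nbytes': int(np.prod(shape_A)) * itemsize,\n",
        "    'wgt_load_nbytes': int(np.prod(shape_B)) * itemsize,\n",
        "    'out_store_nbytes': int(np.prod(shape_C)) * itemsize,\n",
        "}\n",
        "\n",
        "# These match the corresponding counters in the recorded simulator output\n",
        "assert expected_bytes == {\n",
        "    'inp_load_nbytes': 256,\n",
        "    'wgt_load_nbytes': 65536,\n",
        "    'out_store_nbytes': 256,\n",
        "}"
      ]
    },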
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
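        "```{topic} Summary\n",
        "This tutorial shows the TVM workflow for implementing a simple matrix multiplication on VTA.\n",
        "\n",
        "The general workflow includes:\n",
        "\n",
        "- Programming the FPGA with the VTA bitstream over RPC.\n",
        "- Describing the matrix multiplication via a series of computations.\n",
        "- Describing how the computation should be performed using schedule primitives.\n",
        "- Compiling the function to the VTA target.\n",
        "- Running the compiled module and verifying it against a numpy implementation.\n",
        "```"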
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
